/*
    MIT License

    Copyright (c) 2021 Zhepei Wang (wangzhepei@live.com)

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to
    deal in the Software without restriction, including without limitation the
    rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
    sell copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
    IN THE SOFTWARE.
*/

#ifndef MINCO_TRAJECTORY_FLATNESS_HPP
#define MINCO_TRAJECTORY_FLATNESS_HPP

#include <Eigen/Eigen>

#include <cmath>
#include <concepts>
#include <stdexcept>
#include <string>
#include <utility>

#include <yaml-cpp/yaml.h>

namespace minco::flatness
{

template <typename T>
concept ForwardQuery = requires(T value) {
    { value.velocity } -> std::convertible_to<Eigen::Vector3d>;
    { value.acceleration } -> std::convertible_to<Eigen::Vector3d>;
    { value.jerk } -> std::convertible_to<Eigen::Vector3d>;
};

template <typename T>
concept ForwardResult = requires(T value) {
    { value.thrust } -> std::convertible_to<double>;
    { value.quaternion } -> std::convertible_to<Eigen::Vector4d>;
    { value.angular_velocity } -> std::convertible_to<Eigen::Vector3d>;
};

template <typename T>
concept BackwardQuery = requires(T value) {
    { value.position_gradient } -> std::convertible_to<Eigen::Vector3d>;
    { value.velocity_gradient } -> std::convertible_to<Eigen::Vector3d>;
    { value.thrust_gradient } -> std::convertible_to<double>;
    { value.quaternion_gradient } -> std::convertible_to<Eigen::Vector4d>;
    { value.angular_velocity_gradient } -> std::convertible_to<Eigen::Vector3d>;
};

template <typename T>
concept BackwardResult = requires(T value) {
    { value.position_total_gradient } -> std::convertible_to<Eigen::Vector3d>;
    { value.velocity_total_gradient } -> std::convertible_to<Eigen::Vector3d>;
    { value.acceleration_total_gradient } -> std::convertible_to<Eigen::Vector3d>;
    { value.jerk_total_gradient } -> std::convertible_to<Eigen::Vector3d>;
};

template <typename Model>
concept FlatnessModel = requires(Model model, const typename Model::ConfigType &config,
                                 const typename Model::ForwardQuery &forward_query,
                                 const typename Model::BackwardQuery &backward_query) {
    ForwardQuery<typename Model::ForwardQuery>;
    ForwardResult<typename Model::ForwardResult>;
    BackwardQuery<typename Model::BackwardQuery>;
    BackwardResult<typename Model::BackwardResult>;

    { model.configure(config) } -> std::same_as<void>;
    { model.configure_from_file(std::declval<std::string>()) } -> std::same_as<void>;
    { model.forward(forward_query) } -> std::same_as<typename Model::ForwardResult>;
    { model.backward(backward_query) } -> std::same_as<typename Model::BackwardResult>;
};

struct DefaultConfig
{
    double mass            = 1.0;
    double gravity         = 9.81;
    double horizontal_drag = 0.0;
    double vertical_drag   = 0.0;
    double parasitic_drag  = 0.0;
    double speed_smooth    = 1.0e-3;
};

struct DefaultForwardQuery
{
    Eigen::Vector3d velocity      = Eigen::Vector3d::Zero();
    Eigen::Vector3d acceleration  = Eigen::Vector3d::Zero();
    Eigen::Vector3d jerk          = Eigen::Vector3d::Zero();
    double          yaw           = 0.0;
    double          yaw_rate      = 0.0;
};

struct DefaultForwardResult
{
    double          thrust           = 0.0;
    Eigen::Vector4d quaternion       = Eigen::Vector4d::Zero();
    Eigen::Vector3d angular_velocity = Eigen::Vector3d::Zero();
};

struct DefaultBackwardQuery
{
    Eigen::Vector3d position_gradient       = Eigen::Vector3d::Zero();
    Eigen::Vector3d velocity_gradient       = Eigen::Vector3d::Zero();
    double          thrust_gradient         = 0.0;
    Eigen::Vector4d quaternion_gradient     = Eigen::Vector4d::Zero();
    Eigen::Vector3d angular_velocity_gradient = Eigen::Vector3d::Zero();
};

struct DefaultBackwardResult
{
    Eigen::Vector3d position_total_gradient     = Eigen::Vector3d::Zero();
    Eigen::Vector3d velocity_total_gradient     = Eigen::Vector3d::Zero();
    Eigen::Vector3d acceleration_total_gradient = Eigen::Vector3d::Zero();
    Eigen::Vector3d jerk_total_gradient         = Eigen::Vector3d::Zero();
    double          yaw_total_gradient          = 0.0;
    double          yaw_rate_total_gradient     = 0.0;
};

class FlatnessMap
{
   public:
    using ConfigType       = DefaultConfig;
    using ForwardQuery     = DefaultForwardQuery;
    using ForwardResult    = DefaultForwardResult;
    using BackwardQuery    = DefaultBackwardQuery;
    using BackwardResult   = DefaultBackwardResult;

    static constexpr const char *kDefaultConfigPath =
        "config/default_flatness_config.yaml";

    FlatnessMap() = default;

    inline void configure(const ConfigType &config)
    {
        config_ = config;
        reset(config_.mass, config_.gravity, config_.horizontal_drag,
              config_.vertical_drag, config_.parasitic_drag,
              config_.speed_smooth);
    }

    inline const ConfigType &config() const { return config_; }

    inline void configure_from_file(const std::string &file_path = kDefaultConfigPath)
    {
        const std::string path = file_path.empty() ? kDefaultConfigPath : file_path;
        try
        {
            YAML::Node root = YAML::LoadFile(path);
            const YAML::Node node = root["flatness"].IsDefined()
                                         ? root["flatness"]
                                         : root;

            ConfigType parsed = config_;
            if (node["mass"]) parsed.mass = node["mass"].as<double>();
            if (node["gravity"]) parsed.gravity = node["gravity"].as<double>();
            if (node["horizontal_drag"])
                parsed.horizontal_drag = node["horizontal_drag"].as<double>();
            if (node["vertical_drag"])
                parsed.vertical_drag = node["vertical_drag"].as<double>();
            if (node["parasitic_drag"])
                parsed.parasitic_drag = node["parasitic_drag"].as<double>();
            if (node["speed_smooth"])
                parsed.speed_smooth = node["speed_smooth"].as<double>();

            configure(parsed);
        }
        catch (const YAML::Exception &ex)
        {
            throw std::runtime_error("Failed to load flatness config from " +
                                     path + ": " + ex.what());
        }
    }

    inline ForwardResult forward(const ForwardQuery &query)
    {
        ForwardResult result;
        forward(query.velocity, query.acceleration, query.jerk, query.yaw,
                query.yaw_rate, result.thrust, result.quaternion,
                result.angular_velocity);
        return result;
    }

    inline BackwardResult backward(const BackwardQuery &query) const
    {
        BackwardResult result;
        backward(query.position_gradient, query.velocity_gradient,
                 query.thrust_gradient, query.quaternion_gradient,
                 query.angular_velocity_gradient,
                 result.position_total_gradient,
                 result.velocity_total_gradient,
                 result.acceleration_total_gradient,
                 result.jerk_total_gradient, result.yaw_total_gradient,
                 result.yaw_rate_total_gradient);
        return result;
    }

    // Legacy interface retained for backward compatibility with existing code.
    inline void reset(const double &vehicle_mass, const double &gravitational_acceleration,
                      const double &horizontal_drag_coeff, const double &vertical_drag_coeff,
                      const double &parasitic_drag_coeff, const double &speed_smooth_factor)
    {
        config_.mass            = vehicle_mass;
        config_.gravity         = gravitational_acceleration;
        config_.horizontal_drag = horizontal_drag_coeff;
        config_.vertical_drag   = vertical_drag_coeff;
        config_.parasitic_drag  = parasitic_drag_coeff;
        config_.speed_smooth    = speed_smooth_factor;

        mass = vehicle_mass;
        grav = gravitational_acceleration;
        dh   = horizontal_drag_coeff;
        dv   = vertical_drag_coeff;
        cp   = parasitic_drag_coeff;
        veps = speed_smooth_factor;
    }

    inline void forward(const Eigen::Vector3d &vel, const Eigen::Vector3d &acc,
                        const Eigen::Vector3d &jer, const double &psi,
                        const double &dpsi, double &thr, Eigen::Vector4d &quat,
                        Eigen::Vector3d &omg)
    {
        double w0, w1, w2, dw0, dw1, dw2;

        v0          = vel(0);
        v1          = vel(1);
        v2          = vel(2);
        a0          = acc(0);
        a1          = acc(1);
        a2          = acc(2);
        cp_term     = sqrt(v0 * v0 + v1 * v1 + v2 * v2 + veps);
        w_term      = 1.0 + cp * cp_term;
        w0          = w_term * v0;
        w1          = w_term * v1;
        w2          = w_term * v2;
        dh_over_m   = dh / mass;
        zu0         = a0 + dh_over_m * w0;
        zu1         = a1 + dh_over_m * w1;
        zu2         = a2 + dh_over_m * w2 + grav;
        zu_sqr0     = zu0 * zu0;
        zu_sqr1     = zu1 * zu1;
        zu_sqr2     = zu2 * zu2;
        zu01        = zu0 * zu1;
        zu12        = zu1 * zu2;
        zu02        = zu0 * zu2;
        zu_sqr_norm = zu_sqr0 + zu_sqr1 + zu_sqr2;
        zu_norm     = sqrt(zu_sqr_norm);
        z0          = zu0 / zu_norm;
        z1          = zu1 / zu_norm;
        z2          = zu2 / zu_norm;
        ng_den      = zu_sqr_norm * zu_norm;
        ng00        = (zu_sqr1 + zu_sqr2) / ng_den;
        ng01        = -zu01 / ng_den;
        ng02        = -zu02 / ng_den;
        ng11        = (zu_sqr0 + zu_sqr2) / ng_den;
        ng12        = -zu12 / ng_den;
        ng22        = (zu_sqr0 + zu_sqr1) / ng_den;
        v_dot_a     = v0 * a0 + v1 * a1 + v2 * a2;
        dw_term     = cp * v_dot_a / cp_term;
        dw0         = w_term * a0 + dw_term * v0;
        dw1         = w_term * a1 + dw_term * v1;
        dw2         = w_term * a2 + dw_term * v2;
        dz_term0    = jer(0) + dh_over_m * dw0;
        dz_term1    = jer(1) + dh_over_m * dw1;
        dz_term2    = jer(2) + dh_over_m * dw2;
        dz0         = ng00 * dz_term0 + ng01 * dz_term1 + ng02 * dz_term2;
        dz1         = ng01 * dz_term0 + ng11 * dz_term1 + ng12 * dz_term2;
        dz2         = ng02 * dz_term0 + ng12 * dz_term1 + ng22 * dz_term2;
        f_term0     = mass * a0 + dv * w0;
        f_term1     = mass * a1 + dv * w1;
        f_term2     = mass * (a2 + grav) + dv * w2;
        thr         = z0 * f_term0 + z1 * f_term1 + z2 * f_term2;
        tilt_den    = sqrt(2.0 * (1.0 + z2));
        tilt0       = 0.5 * tilt_den;
        tilt1       = -z1 / tilt_den;
        tilt2       = z0 / tilt_den;
        c_half_psi  = cos(0.5 * psi);
        s_half_psi  = sin(0.5 * psi);
        quat(0)     = tilt0 * c_half_psi;
        quat(1)     = tilt1 * c_half_psi + tilt2 * s_half_psi;
        quat(2)     = tilt2 * c_half_psi - tilt1 * s_half_psi;
        quat(3)     = tilt0 * s_half_psi;
        c_psi       = cos(psi);
        s_psi       = sin(psi);
        omg_den     = z2 + 1.0;
        omg_term    = dz2 / omg_den;
        omg(0)      = dz0 * s_psi - dz1 * c_psi -
                 (z0 * s_psi - z1 * c_psi) * omg_term;
        omg(1) = dz0 * c_psi + dz1 * s_psi -
                 (z0 * c_psi + z1 * s_psi) * omg_term;
        omg(2) = (z1 * dz0 - z0 * dz1) / omg_den + dpsi;
    }

   inline void backward(const Eigen::Vector3d &pos_grad,
                         const Eigen::Vector3d &vel_grad, const double &thr_grad,
                         const Eigen::Vector4d &quat_grad,
                         const Eigen::Vector3d &omg_grad,
                         Eigen::Vector3d &pos_total_grad,
                         Eigen::Vector3d &vel_total_grad,
                         Eigen::Vector3d &acc_total_grad,
                         Eigen::Vector3d &jer_total_grad, double &psi_total_grad,
                         double &dpsi_total_grad) const
    {
        double w0b, w1b, w2b, dw0b, dw1b, dw2b;
        double z0b, z1b, z2b, dz0b, dz1b, dz2b;
        double v_sqr_normb, cp_termb, w_termb;
        double zu_sqr_normb, zu_normb, zu0b, zu1b, zu2b;
        double zu_sqr0b, zu_sqr1b, zu_sqr2b, zu01b, zu12b, zu02b;
        double ng00b, ng01b, ng02b, ng11b, ng12b, ng22b, ng_denb;
        double dz_term0b, dz_term1b, dz_term2b, f_term0b, f_term1b, f_term2b;
        double tilt_denb, tilt0b, tilt1b, tilt2b, head0b, head3b;
        double cpsib, spsib, omg_denb, omg_termb;
        double tempb, tilt_den_sqr;

        tilt0b = s_half_psi * (quat_grad(3)) + c_half_psi * (quat_grad(0));
        head3b = tilt0 * (quat_grad(3)) + tilt2 * (quat_grad(1)) -
                 tilt1 * (quat_grad(2));
        tilt2b = c_half_psi * (quat_grad(2)) + s_half_psi * (quat_grad(1));
        head0b = tilt2 * (quat_grad(2)) + tilt1 * (quat_grad(1)) +
                 tilt0 * (quat_grad(0));
        tilt1b = c_half_psi * (quat_grad(1)) - s_half_psi * (quat_grad(2));
        tilt_den_sqr = tilt_den * tilt_den;
        tilt_denb =
            (z1 * tilt1b - z0 * tilt2b) / tilt_den_sqr + 0.5 * tilt0b;
        omg_termb = -((z0 * c_psi + z1 * s_psi) * (omg_grad(1))) -
                    (z0 * s_psi - z1 * c_psi) * (omg_grad(0));
        tempb           = omg_grad(2) / omg_den;
        dpsi_total_grad = omg_grad(2);
        z1b             = dz0 * tempb;
        dz0b = z1 * tempb + c_psi * (omg_grad(1)) + s_psi * (omg_grad(0));
        z0b  = -(dz1 * tempb);
        dz1b = s_psi * (omg_grad(1)) - z0 * tempb - c_psi * (omg_grad(0));
        omg_denb = -((z1 * dz0 - z0 * dz1) * tempb / omg_den) -
                   dz2 * omg_termb / (omg_den * omg_den);
        tempb = -(omg_term * (omg_grad(1)));
        cpsib = dz0 * (omg_grad(1)) + z0 * tempb;
        spsib = dz1 * (omg_grad(1)) + z1 * tempb;
        z0b += c_psi * tempb;
        z1b += s_psi * tempb;
        tempb = -(omg_term * (omg_grad(0)));
        spsib += dz0 * (omg_grad(0)) + z0 * tempb;
        cpsib += -dz1 * (omg_grad(0)) - z1 * tempb;
        z0b += s_psi * tempb + tilt2b / tilt_den + f_term0 * (thr_grad);
        z1b += -c_psi * tempb - tilt1b / tilt_den + f_term1 * (thr_grad);
        dz2b = omg_termb / omg_den;
        z2b  = omg_denb + tilt_denb / tilt_den + f_term2 * (thr_grad);
        psi_total_grad = c_psi * spsib + 0.5 * c_half_psi * head3b -
                         s_psi * cpsib - 0.5 * s_half_psi * head0b;
        f_term0b          = z0 * (thr_grad);
        f_term1b          = z1 * (thr_grad);
        f_term2b          = z2 * (thr_grad);
        ng02b             = dz_term0 * dz2b + dz_term2 * dz0b;
        dz_term0b         = ng02 * dz2b + ng01 * dz1b + ng00 * dz0b;
        ng12b             = dz_term1 * dz2b + dz_term2 * dz1b;
        dz_term1b         = ng12 * dz2b + ng11 * dz1b + ng01 * dz0b;
        ng22b             = dz_term2 * dz2b;
        dz_term2b         = ng22 * dz2b + ng12 * dz1b + ng02 * dz0b;
        ng01b             = dz_term0 * dz1b + dz_term1 * dz0b;
        ng11b             = dz_term1 * dz1b;
        ng00b             = dz_term0 * dz0b;
        jer_total_grad(2) = dz_term2b;
        dw2b              = dh_over_m * dz_term2b;
        jer_total_grad(1) = dz_term1b;
        dw1b              = dh_over_m * dz_term1b;
        jer_total_grad(0) = dz_term0b;
        dw0b              = dh_over_m * dz_term0b;
        tempb = cp * (v2 * dw2b + v1 * dw1b + v0 * dw0b) / cp_term;
        acc_total_grad(2) = mass * f_term2b + w_term * dw2b + v2 * tempb;
        acc_total_grad(1) = mass * f_term1b + w_term * dw1b + v1 * tempb;
        acc_total_grad(0) = mass * f_term0b + w_term * dw0b + v0 * tempb;
        vel_total_grad(2) = dw_term * dw2b + a2 * tempb;
        vel_total_grad(1) = dw_term * dw1b + a1 * tempb;
        vel_total_grad(0) = dw_term * dw0b + a0 * tempb;
        cp_termb          = -(v_dot_a * tempb / cp_term);
        tempb             = ng22b / ng_den;
        zu_sqr0b          = tempb;
        zu_sqr1b          = tempb;
        ng_denb           = -((zu_sqr0 + zu_sqr1) * tempb / ng_den);
        zu12b             = -(ng12b / ng_den);
        tempb             = ng11b / ng_den;
        ng_denb += zu12 * ng12b / (ng_den * ng_den) -
                   (zu_sqr0 + zu_sqr2) * tempb / ng_den;
        zu_sqr0b += tempb;
        zu_sqr2b = tempb;
        zu02b    = -(ng02b / ng_den);
        zu01b    = -(ng01b / ng_den);
        tempb    = ng00b / ng_den;
        ng_denb += zu02 * ng02b / (ng_den * ng_den) +
                   zu01 * ng01b / (ng_den * ng_den) -
                   (zu_sqr1 + zu_sqr2) * tempb / ng_den;
        zu_normb = zu_sqr_norm * ng_denb -
                   (zu2 * z2b + zu1 * z1b + zu0 * z0b) / zu_sqr_norm;
        zu_sqr_normb = zu_norm * ng_denb + zu_normb / (2.0 * zu_norm);
        tempb += zu_sqr_normb;
        zu_sqr1b += tempb;
        zu_sqr2b += tempb;
        zu2b =
            z2b / zu_norm + zu0 * zu02b + zu1 * zu12b + 2 * zu2 * zu_sqr2b;
        w2b = dv * f_term2b + dh_over_m * zu2b;
        zu1b =
            z1b / zu_norm + zu2 * zu12b + zu0 * zu01b + 2 * zu1 * zu_sqr1b;
        w1b = dv * f_term1b + dh_over_m * zu1b;
        zu_sqr0b += zu_sqr_normb;
        zu0b =
            z0b / zu_norm + zu2 * zu02b + zu1 * zu01b + 2 * zu0 * zu_sqr0b;
        w0b     = dv * f_term0b + dh_over_m * zu0b;
        w_termb = a2 * dw2b + a1 * dw1b + a0 * dw0b + v2 * w2b + v1 * w1b +
                  v0 * w0b;
        acc_total_grad(2) += zu2b;
        acc_total_grad(1) += zu1b;
        acc_total_grad(0) += zu0b;
        cp_termb += cp * w_termb;
        v_sqr_normb = cp_termb / (2.0 * cp_term);
        vel_total_grad(2) += w_term * w2b + 2 * v2 * v_sqr_normb + vel_grad(2);
        vel_total_grad(1) += w_term * w1b + 2 * v1 * v_sqr_normb + vel_grad(1);
        vel_total_grad(0) += w_term * w0b + 2 * v0 * v_sqr_normb + vel_grad(0);
        pos_total_grad(2) = pos_grad(2);
        pos_total_grad(1) = pos_grad(1);
        pos_total_grad(0) = pos_grad(0);
    }

   private:
    ConfigType config_{};

    double mass = 1.0;
    double grav = 9.81;
    double dh   = 0.0;
    double dv   = 0.0;
    double cp   = 0.0;
    double veps = 1.0e-3;

    double v0 = 0.0, v1 = 0.0, v2 = 0.0, a0 = 0.0, a1 = 0.0, a2 = 0.0, v_dot_a = 0.0;
    double z0 = 0.0, z1 = 0.0, z2 = 0.0, dz0 = 0.0, dz1 = 0.0, dz2 = 0.0;
    double cp_term = 0.0, w_term = 0.0, dh_over_m = 0.0;
    double zu_sqr_norm = 0.0, zu_norm = 0.0, zu0 = 0.0, zu1 = 0.0, zu2 = 0.0;
    double zu_sqr0 = 0.0, zu_sqr1 = 0.0, zu_sqr2 = 0.0, zu01 = 0.0, zu12 = 0.0, zu02 = 0.0;
    double ng00 = 0.0, ng01 = 0.0, ng02 = 0.0, ng11 = 0.0, ng12 = 0.0, ng22 = 0.0, ng_den = 0.0;
    double dw_term = 0.0, dz_term0 = 0.0, dz_term1 = 0.0, dz_term2 = 0.0;
    double f_term0 = 0.0, f_term1 = 0.0, f_term2 = 0.0;
    double tilt_den = 0.0, tilt0 = 0.0, tilt1 = 0.0, tilt2 = 0.0, c_half_psi = 0.0, s_half_psi = 0.0;
    double c_psi = 0.0, s_psi = 0.0, omg_den = 0.0, omg_term = 0.0;

};

}  // namespace minco::flatness

#endif  // MINCO_TRAJECTORY_FLATNESS_HPP
